#!/usr/bin/env python3

import struct
import sys
import os
import tempfile
import subprocess


def die(errmsg):
    """Print *errmsg* to standard output and terminate with exit status 1.

    Args:
        errmsg: Message shown to the user before exiting.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    print(errmsg)
    # Use sys.exit rather than the bare ``exit`` builtin: the latter is
    # injected by the ``site`` module for interactive sessions and is not
    # defined when the interpreter is started with ``-S``.
    sys.exit(1)


def align(s, size):
    """Return *s* padded with NUL bytes up to a multiple of *size*.

    Data whose length is already a multiple of *size* is returned unchanged.
    """
    remainder = len(s) % size
    if remainder == 0:
        return s
    s += b'\0' * (size - remainder)
    return s


def get_temp_dsk_name():
    """Create an empty temporary ``.dsk`` file and return its path.

    The file is created atomically by ``tempfile.mkstemp`` so the returned
    name is reserved for the caller; deriving a name from an already-deleted
    temporary file would leave a window in which another process could claim
    the same path.  The caller is responsible for removing the file.

    Returns:
        Path of the newly created, empty file (ends with ``.dsk``).
    """
    fd, path = tempfile.mkstemp(suffix='.dsk')
    # Only the name is needed here; callers reopen the file themselves.
    os.close(fd)
    return path


class Pintos(object):
    """Run a Pintos kernel image (``os.dsk``) under QEMU.

    The runner resolves the block devices to attach, optionally builds a
    scratch disk used to copy files into and out of the guest, patches the
    kernel command line into a temporary copy of ``os.dsk`` and finally
    launches ``qemu-system-x86_64``.

    Args:
        ttest: Stored on the instance; not otherwise used by this class.
        mem: Guest memory size passed to QEMU's ``-m`` option.
        no_vga: When true, QEMU is started with ``-nographic``.
        serial: Accepted for interface compatibility; not used by this class.
        args: Kernel command-line arguments.
        mnts: Extra disk images attached after the four standard drives.
        hostfns: Files to copy into the guest, each ``[host]`` or
            ``[host, guest_name]``.
        guestfns: Files to copy out of the guest, each ``[guest]`` or
            ``[guest, host_name]``.
        gdb: When true, QEMU is started with ``-s -S``.
        fs: Path of the file-system disk, or an integer size for a temporary
            disk.
        swap: Path of the swap disk, or an integer size for a temporary disk.
        timeout: Seconds after which QEMU is killed; 0 disables the timeout.
    """

    def __init__(self, ttest=False, mem=256, no_vga=True, serial=False,
                 args=None, mnts=None, hostfns=None, guestfns=None, gdb=False,
                 fs='fs.dsk', swap='swap.dsk', timeout=0):
        self.ttest = ttest
        self.mem = mem
        self.no_vga = no_vga
        # None stands in for "empty list" so instances never share a default.
        self.args = args if args is not None else []
        self.gdb = gdb
        self.proc = None
        self.timeout = timeout
        self.host_fns = hostfns if hostfns is not None else []
        self.guest_fns = guestfns if guestfns is not None else []
        self.mnts = mnts if mnts is not None else []
        self.bdevs = {'os': 'os.dsk', 'fs': fs, 'swap': swap}
        # Paths of temporary disk images created by this instance; only
        # these are deleted when run() finishes.
        self._temp_files = []

    def __scan_dir(self):
        """Resolve every configured block device to an existing image file.

        An entry naming an existing path is kept as is.  An entry that is
        not an existing path but parses as an integer ``size`` is replaced
        by a temporary zero-filled image of ``0xfc000 * size`` bytes.  Any
        other entry is dropped, except for ``os`` which aborts the program.

        Returns:
            Mapping from device key to image path.
        """
        resolved = {}
        for key, spec in self.bdevs.items():
            if os.path.exists(spec):
                resolved[key] = spec
                continue
            try:
                size = int(spec)
            except (TypeError, ValueError):
                if key == 'os':
                    die('os.dsk cannot be temporal.')
                # Optional disks that are neither a file nor a size are
                # simply not attached.
                continue
            path = get_temp_dsk_name()
            self._temp_files.append(path)
            with open(path, 'wb') as f:
                f.write(b'\0' * (0xfc000 * size))
            resolved[key] = path
        return resolved

    def __prepare_scratch_files(self):
        """Build the scratch disk used to exchange files with the guest.

        Each host file is stored as a 512-byte header (``PUT\\0``, the
        little-endian 32-bit length, zero padding) followed by the file data
        padded to a 512-byte boundary.  For every requested guest file
        0x100000 zero bytes are appended.

        Returns:
            ``(puts, gets)``: the guest-side names of the files being put and
            the list of guest file specifications to fetch afterwards.
        """
        puts = []
        scratch = get_temp_dsk_name()
        self._temp_files.append(scratch)
        self.bdevs['scratch'] = scratch
        # The context manager closes the disk even if a host file is missing.
        with open(scratch, 'wb') as disk:
            for fname in self.host_fns:
                host = fname[0]
                puts.append(fname[1] if len(fname) > 1 else host)
                with open(host, 'rb') as hf:
                    data = hf.read()
                disk.write(b'PUT\0' +
                           struct.pack("<I", len(data)) +
                           b'\0' * 504)
                disk.write(align(data, 512))

            for _ in self.guest_fns:
                disk.write(b'\0' * 0x100000)

        gets = list(self.guest_fns)
        return puts, gets

    def __prepare_kernel_argument(self, puts, gets):
        """Write a copy of ``os.dsk`` with the kernel command line patched in.

        Leading arguments that start with ``-`` come first, then a
        ``put NAME`` pair per file in *puts*, then the remaining user
        arguments, then a ``get NAME`` pair per file in *gets*.  The argument
        count is stored as a little-endian 32-bit integer at offset 0x17a and
        is followed by the NUL-separated arguments padded to 128 bytes, which
        ends at offset 0x1fe.

        Returns:
            Path of the temporary, patched image.
        """
        # Index of the first non-option argument (an empty string counts as
        # a non-option instead of raising IndexError).
        split_at = len(self.args)
        for idx, arg in enumerate(self.args):
            if not arg.startswith('-'):
                split_at = idx
                break

        args = list(self.args[:split_at])
        for put in puts:
            args.extend(['put', put])
        args.extend(self.args[split_at:])
        for get in gets:
            args.extend(['get', get[0]])

        # Measure the encoded form: the on-disk field is 128 *bytes*, and a
        # longer payload would shift everything that follows it.
        cmd = ''.join('{}\0'.format(c) for c in args).encode('utf-8')
        if len(cmd) > 128:
            die("command line exceeds 128 bytes")

        name = get_temp_dsk_name()
        self._temp_files.append(name)

        with open('os.dsk', 'rb') as f:
            data = f.read()

        with open(name, 'wb') as f:
            f.write(data[:0x17a] +
                    struct.pack("<I", len(args)) +
                    cmd.ljust(128, b'\0') +
                    data[0x1fe:])
        return name

    def __prepare_cmd(self):
        """Return the QEMU command line as a list of arguments."""
        cmd = ['qemu-system-x86_64']
        if self.no_vga:
            cmd.append('-nographic')
        if self.gdb:
            cmd.extend(['-s', '-S'])

        # The drive index is fixed by position so that a missing disk does
        # not renumber the ones after it.
        for idx, d in enumerate(['os', 'fs', 'scratch', 'swap']):
            if self.bdevs.get(d, None):
                cmd.extend(['-drive',
                            'file={},format=raw,index={},media=disk'
                            .format(self.bdevs[d], idx)])
        for idx, mnt in enumerate(self.mnts):
            cmd.extend(['-drive',
                        'file={},format=raw,index={},media=disk'
                        .format(mnt, 4 + idx)])

        cmd.extend(['-cpu', 'qemu64'])
        cmd.extend(['-m', str(self.mem)])
        cmd.extend(['-no-reboot'])
        # cmd.extend(['-enable-kvm']) # Sadly, kvm is not available on server.
        cmd.extend(['-serial', 'mon:stdio'])
        return cmd

    def get_files(self, gets):
        """Copy files from the scratch disk to the host.

        Each file on the scratch disk is expected as a 512-byte header
        (``GET\\0``, little-endian 32-bit size, padding) followed by the data
        padded to a 512-byte boundary.

        Args:
            gets: Sequence of ``[guest_name]`` or ``[guest_name, host_name]``
                entries; the data is written to ``host_name``, or to
                ``guest_name`` when no host name is given.
        """
        if not gets:
            return
        with open(self.bdevs['scratch'], 'rb') as f:
            for get in gets:
                if f.read(4) != b'GET\0':
                    print('bad signature on scratch disk')
                    # Without a valid header the position of the following
                    # files is unknown, so stop here.
                    break
                size = struct.unpack("<I", f.read(4))[0]
                f.read(504)  # skip to the next sector
                host_name = get[1] if len(get) > 1 else get[0]
                with open(host_name, 'wb') as g:  # Copy file data.
                    g.write(f.read(size))
                # Skip forward in disk up to beginning of next sector.
                f.seek((-size) % 512, os.SEEK_CUR)

    def __remove_temp_files(self):
        """Delete every temporary disk image created by this instance."""
        while self._temp_files:
            path = self._temp_files.pop()
            try:
                os.remove(path)
            except FileNotFoundError:
                # The name was reserved but the image was never written.
                pass

    def run(self):
        """Prepare the disks, run QEMU and collect the requested files.

        ``TIMEOUT`` is written to standard output when QEMU exceeds the
        configured timeout.  Temporary disk images created along the way are
        removed even if preparation or the run itself fails; disk images
        supplied by the caller are never deleted.
        """
        try:
            self.bdevs = self.__scan_dir()
            if self.host_fns or self.guest_fns:
                puts, gets = self.__prepare_scratch_files()
            else:
                puts, gets = [], []

            self.bdevs['os'] = self.__prepare_kernel_argument(puts, gets)
            cmd = self.__prepare_cmd()
            run_kwargs = {'stdin': sys.stdin, 'stdout': sys.stdout,
                          'stderr': sys.stderr}
            if self.timeout != 0:
                run_kwargs['timeout'] = self.timeout
            try:
                subprocess.run(cmd, **run_kwargs)
            except subprocess.TimeoutExpired:
                sys.stdout.write("TIMEOUT")
            finally:
                self.get_files(gets)
        finally:
            self.__remove_temp_files()


# Command-line entry point: options before a literal ``--`` configure the
# runner, everything after it is passed to the Pintos kernel.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
            description='a utility for running Pintos in a simulator')
    # NOTE(review): store_true combined with default=True makes this value
    # always True, so the flag currently has no effect.
    parser.add_argument('-v', '--no-vga', action='store_true', default=True,
                        help='No VGA display or keyboard')
    # Parsed for compatibility only; the value is not forwarded to Pintos.
    parser.add_argument('-k', '--kill-on-failure', action='store_true',
                        help='Kill Pintos a few seconds after a kernel or user '
                        'panic, test failure, or triple fault (deprecated)')
    parser.add_argument('-T', '--timeout', type=int, default=0,
                        help='Kill Pintos after N seconds '
                             '(0 disables the timeout)')

    parser.add_argument('-m', '--memory', type=int, default=256,
                        help='memory capacity')
    parser.add_argument('--fs-disk', default='fs.dsk',
                        help='Set FS disk file or size')
    parser.add_argument('--swap-disk', default='swap.dsk',
                        help='Set SWAP disk file or size')
    parser.add_argument('-p', '--put-file', dest='HOSTFNS', nargs=1,
                        action='append', default=[],
                        help='Copy HOSTFN into VM, split by ":".'
                             ' (e.g. tests/userprog/args-none:args-none)')
    parser.add_argument('-g', '--get-file', dest='GUESTFNS', nargs=1,
                        action='append', default=[],
                        help='Copy GUESTFN out of VM, '
                             'by default under same name')
    parser.add_argument('--mnts', dest='MNTS', nargs=1,
                        action='append', default=[],
                        help='Additional mounting disks')
    parser.add_argument('--gdb', action='store_true', default=False,
                        help='Debug with gdb')
    parser.add_argument('-t', '--threads-tests', action='store_true',
                        default=False,
                        help='Run proj1 test cases with USERPROG flag')

    # Split by hand so that kernel arguments beginning with '-' are never
    # interpreted as options of this utility.
    if '--' in sys.argv:
        pintos_arg_index = sys.argv.index('--')
        util_args = sys.argv[1: pintos_arg_index]
        kern_args = sys.argv[pintos_arg_index + 1:]
    else:
        util_args = sys.argv[1:]
        kern_args = []

    args = parser.parse_args(util_args)
    # nargs=1 with action='append' yields one-element lists, hence the f[0].
    Pintos(ttest=args.threads_tests, mem=args.memory, no_vga=args.no_vga,
           args=kern_args, timeout=args.timeout, fs=args.fs_disk, gdb=args.gdb,
           swap=args.swap_disk,
           mnts=[f[0] for f in args.MNTS],
           hostfns=[f[0].split(':') for f in args.HOSTFNS],
           guestfns=[f[0].split(':') for f in args.GUESTFNS]).run()
